import numpy as np
import matplotlib.pyplot as plt
import torch
from warnings import filterwarnings
import torchkbnufft as tkbn

# Centered FFT over specified axes using PyTorch
def fftc(x, axes=None):
    """
    Centred, orthonormal forward FFT.

    Casts ``x`` to complex64, then applies ifftshift -> fftn(norm='ortho') ->
    fftshift over ``axes`` so the zero-frequency bin sits at the centre.

    Args:
        x: Input tensor (any dtype castable to complex64).
        axes: Dimensions to transform; ``None`` transforms every dimension.

    Returns:
        complex64 tensor with the same shape as ``x``.
    """
    as_complex = x.type(torch.complex64)
    spectrum = torch.fft.fftn(
        torch.fft.ifftshift(as_complex, dim=axes), dim=axes, norm='ortho'
    )
    return torch.fft.fftshift(spectrum, dim=axes)

# Centered IFFT over specified axes using PyTorch
def ifftc(x, axes=None):
    """
    Centred, orthonormal inverse FFT that returns only the magnitude.

    Applies ifftshift -> ifftn(norm='ortho') -> fftshift over ``axes`` and
    returns the absolute value of the result. The phase is discarded, so the
    output is a real-valued tensor.

    Args:
        x: Centred k-space tensor (complex or real).
        axes: Dimensions to transform; ``None`` transforms every dimension.

    Returns:
        Real tensor holding the magnitude of the centred inverse transform.
    """
    centred_image = torch.fft.fftshift(
        torch.fft.ifftn(torch.fft.ifftshift(x, dim=axes), dim=axes, norm='ortho'),
        dim=axes,
    )
    return centred_image.abs()

def noiser(data, sigma = 0.05):
    """
    Return ``data`` corrupted by additive zero-mean Gaussian noise.

    Args:
        data: Input tensor to be noised.
        sigma: Standard deviation of the added noise (default 0.05).

    Returns:
        Tensor of the same shape as ``data`` equal to
        ``data + sigma * N(0, 1)``, with noise drawn from torch's global RNG
        on ``data``'s device.
    """
    # randn_like already allocates on data's device and with data's dtype.
    noise = torch.randn_like(data)
    return data + noise * sigma

def forward(data, mask, is_nufft, acceleration, dtype):
    """
    Simulate an undersampled acquisition and return the degraded image.

    The input is mapped from [-1, 1] to [0, 1]. It is then passed through
    either a Cartesian masked FFT or a NUFFT forward/adjoint pair, clamped to
    [0, 1], and mapped back to [-1, 1].

    Args:
        data: Image tensor, presumably in [-1, 1] (it is mapped with
            ``data / 2 + 0.5``). The NUFFT path requires a 3-D or 4-D tensor;
            a 3-D tensor is treated as a single-item batch.
        mask: Cartesian path: sampling mask, multiplied elementwise with the
            centred spectrum of the last two axes. NUFFT path: k-space
            trajectory, linearly rescaled to [-pi, pi] before use.
        is_nufft: Falsy selects the Cartesian FFT path; truthy selects NUFFT.
        acceleration: Currently unused. It is only referenced by the disabled
            density-compensation variant of the NUFFT path.
        dtype: dtype the magnitude image is cast to.

    Returns:
        Undersampled magnitude image in [-1, 1]. For the NUFFT path it has
        the same shape as ``data``. For the Cartesian path it has the shape
        of ``data * mask`` after broadcasting.

    Raises:
        AssertionError: NUFFT path only, if ``data`` is not 3-D or 4-D
            (note: asserts are stripped under ``python -O``).
    """
    data = data / 2 + 0.5  # map [-1, 1] -> [0, 1]

    # Truthiness rather than `is False`, so 0 / None / numpy bools also pick
    # the Cartesian path.
    if not is_nufft:
        undersampled_data = _cartesian_undersample(data, mask, dtype)
    else:
        undersampled_data = _nufft_undersample(data, mask, dtype)

    return undersampled_data * 2 - 1  # map [0, 1] -> [-1, 1]


def _cartesian_undersample(data, mask, dtype):
    """Mask the centred 2-D spectrum of ``data`` and return the clamped magnitude image."""
    kspace = fftc(data, axes=(-2, -1))  # complex64
    image = ifftc(kspace * mask, axes=(-2, -1)).type(dtype)
    return torch.clamp(image, 0, 1)


def _nufft_undersample(data, ktraj, dtype):
    """Sample ``data`` on a non-Cartesian trajectory, regrid with the adjoint, return the clamped magnitude."""
    im_size = data.shape[-2:]

    # tkbn works on batched input (assumed (batch, coil, H, W) - confirm).
    # Remember whether a batch axis was added, so that only that axis is
    # removed at the end. A caller's real batch of size 1 is preserved.
    added_batch_dim = data.dim() == 3
    if added_batch_dim:
        data = data.unsqueeze(0)
    assert data.dim() == 4, "Input data must be 4D tensor"
    data = data.type(torch.complex64)

    # grid_size is left at the tkbn default (an oversampled grid such as
    # 2 * im_size was considered but is not used). 'ortho' normalisation is
    # passed on each call below.
    nufft_ob = tkbn.KbNufft(im_size=im_size).to(data)
    adjnufft_ob = tkbn.KbNufftAdjoint(im_size=im_size).to(data)

    # Linearly rescale the trajectory to [-pi, pi], presumably tkbn's
    # radians-per-voxel convention.
    # NOTE(review): a constant trajectory gives max == min and yields NaNs.
    kmin, kmax = torch.min(ktraj), torch.max(ktraj)
    ktraj = (ktraj - kmin) / (kmax - kmin) * (torch.pi - (-torch.pi)) + (-torch.pi)

    kdata = nufft_ob(data, ktraj, norm='ortho')

    # Density compensation (tkbn.calc_density_compensation_function, followed
    # by scaling with `acceleration`) is disabled; a fixed gain of 4 is used
    # instead. NOTE(review): the origin of the factor 4 is not documented.
    image = adjnufft_ob(kdata, ktraj, norm='ortho')
    image = torch.abs(image).type(dtype) * 4
    image = torch.clamp(image, 0, 1)

    if added_batch_dim:
        image = image.squeeze(0)
    return image